#include "LuaHookProtocol.h"
#include "lhook.h"
#include "lua_3rd.h"
#include "SheepsMain.h"

#define CONNECT_OPEN_HOOK_FUNC "proxy_connect_open_hook"
#define CONNECT_FAILE_HOOK_FUNC "proxy_connect_faile_hook"
#define CONNECT_CLOSE_HOOK_FUNC "proxy_connect_close_hook"
#define CONNECT_SEND_HOOK_FUNC "proxy_connect_send_hook"
#define CONNECT_RECV_HOOK_FUNC "proxy_connect_recv_hook"

//static std::map<std::string, bool>* OldVmPackageLoaded = new std::map<std::string, bool>;
//static std::map<std::string, bool>* OldVmGlobalLoaded = new std::map<std::string, bool>;

// Records the name of every string-keyed entry currently present in
// `package.loaded` into `package_loaded` (value is always `true`).
//
// L              - Lua state whose `package.loaded` table is scanned.
// package_loaded - output map; existing entries are kept.
//
// Note: on return the whole stack of L is emptied, not only the values
// pushed by this function.
static void global_package_loaded_init(lua_State* L, std::map<std::string, bool>& package_loaded) {
	lua_getglobal(L, "package");
	lua_pushstring(L, "loaded");
	lua_gettable(L, -2);

	lua_pushnil(L);
	while (lua_next(L, -2) != 0) {
		// Only real string keys are recorded. Calling lua_tostring() on a
		// numeric key converts it in place, which confuses the next
		// lua_next() call; for any other key type it returns NULL, which
		// must never be handed to std::string.
		if (lua_type(L, -2) == LUA_TSTRING) {
			package_loaded.insert(std::pair<std::string, bool>(lua_tostring(L, -2), true));
		}
		lua_pop(L, 1);  // drop the value, keep the key for the next iteration
	}
	lua_pop(L, lua_gettop(L));
}

// Records the name of every string-keyed global currently present in `_G`
// into `parms_loaded` (value is always `true`).
//
// L            - Lua state whose global table is scanned.
// parms_loaded - output map; existing entries are kept.
//
// Note: on return the whole stack of L is emptied, not only the values
// pushed by this function.
static void global_parms_loaded_init(lua_State* L, std::map<std::string, bool>& parms_loaded) {
	lua_getglobal(L, "_G");

	lua_pushnil(L);
	while (lua_next(L, -2) != 0) {
		// Only real string keys are recorded. Calling lua_tostring() on a
		// numeric key converts it in place, which confuses the next
		// lua_next() call; for any other key type it returns NULL, which
		// must never be handed to std::string.
		if (lua_type(L, -2) == LUA_TSTRING) {
			parms_loaded.insert(std::pair<std::string, bool>(lua_tostring(L, -2), true));
		}
		lua_pop(L, 1);  // drop the value, keep the key for the next iteration
	}
	lua_pop(L, lua_gettop(L));
}

// Removes from `package.loaded` every entry whose name is not listed in
// `package_loade`, by assigning nil to it.
//
// L             - Lua state whose `package.loaded` table is pruned.
// package_loade - set of names (map keys) that must be kept.
//
// Keys that are not strings can never appear in the name map, so they are
// treated as "not listed" and removed as well.
//
// Note: on return the whole stack of L is emptied, not only the values
// pushed by this function.
static void reset_luavm_package(lua_State* L, std::map<std::string, bool>& package_loade) {
	lua_getglobal(L, "package");
	lua_pushstring(L, "loaded");
	lua_gettable(L, -2);

	lua_pushnil(L);
	while (lua_next(L, -2) != 0) {
		// Never call lua_tostring() on a non-string key: it would convert a
		// number in place (breaking lua_next) or return NULL.
		bool keep = false;
		if (lua_type(L, -2) == LUA_TSTRING) {
			keep = package_loade.find(lua_tostring(L, -2)) != package_loade.end();
		}
		if (!keep) {
			// Stack: loaded, key, value -> push key copy + nil, then
			// loaded[key] = nil. Clearing an existing field during a
			// traversal is permitted by lua_next.
			lua_pushvalue(L, -2);
			lua_pushnil(L);
			lua_settable(L, -5);
		}
		lua_pop(L, 1);  // drop the value, keep the key for the next iteration
	}
	lua_pop(L, lua_gettop(L));
}

// Removes from `_G` every global whose name is not listed in `parms_loaded`,
// by assigning nil to it.
//
// L            - Lua state whose global table is pruned.
// parms_loaded - set of names (map keys) that must be kept.
//
// Keys that are not strings can never appear in the name map, so they are
// treated as "not listed" and removed as well.
//
// Note: on return the whole stack of L is emptied, not only the values
// pushed by this function.
static void reset_luavm_gloable(lua_State* L, std::map<std::string, bool>& parms_loaded) {
	lua_getglobal(L, "_G");

	lua_pushnil(L);
	while (lua_next(L, -2) != 0) {
		// Never call lua_tostring() on a non-string key: it would convert a
		// number in place (breaking lua_next) or return NULL.
		bool keep = false;
		if (lua_type(L, -2) == LUA_TSTRING) {
			keep = parms_loaded.find(lua_tostring(L, -2)) != parms_loaded.end();
		}
		if (!keep) {
			// With key copy and nil pushed, index -5 is the slot directly
			// below the traversal key (same relative index the function has
			// always used). Clearing an existing field during a traversal is
			// permitted by lua_next.
			lua_pushvalue(L, -2);
			lua_pushnil(L);
			lua_settable(L, -5);
		}
		lua_pop(L, 1);  // drop the value, keep the key for the next iteration
	}
	lua_pop(L, lua_gettop(L));
}

// Reports a failed Lua call and resets the thread that produced it.
//
// L      - thread whose stack top holds the error object.
// status - status code returned by the failing Lua API call.
// proto  - owning protocol object (currently unused here).
// func   - name of the calling function, used in the message.
// line   - source line of the call site, used in the message.
//
// When `status` is not LUA_OK the message is printed to stdout and written
// through LOG, the error object is popped and lua_resetthread() is called on
// L. Returns `status` unchanged in every case.
int sheeps_report(lua_State* L, int status, LuaHookProtocol* proto, const char* func, int line) {
	if (status != LUA_OK) {
		// The error object is not guaranteed to be a string (e.g. a script
		// may raise a table); lua_tostring() then yields NULL, which must
		// not be passed to a "%s" conversion.
		const char* msg = lua_tostring(L, -1);
		if (msg == NULL) {
			msg = "(error object is not a string)";
		}
		printf("sheeps_hook: %s:%d %s\n", func, line, msg);
		LOG(shooklogid, 3, "%s:%d %s\n", func, line, msg);
		lua_pop(L, 1);  /* remove message */
		lua_resetthread(L);
	}
	return status;
}

// Stores `proto` as a light userdata in the Lua registry of L under the key
// `_HookProtoName`.
void set_global(lua_State* L, LuaHookProtocol* proto)
{
	LuaHookProtocol* const owner = proto;
	lua_pushlightuserdata(L, owner);                     // value
	lua_setfield(L, LUA_REGISTRYINDEX, _HookProtoName);  // registry[key] = value
}

// Stores `proto` as a light userdata in the Lua registry of L under the key
// `_HookProtoPpid`.
void set_ppid(lua_State* L, LuaHookProtocol* proto)
{
	LuaHookProtocol* const parent = proto;
	lua_pushlightuserdata(L, parent);                    // value
	lua_setfield(L, LUA_REGISTRYINDEX, _HookProtoPpid);  // registry[key] = value
}

// Hands out a Lua thread (coroutine state) for running a script or handler.
//
// An idle thread is taken from the front of `proto->TLS` when one is
// available. Otherwise a new thread is created on `proto->GL`, anchored in
// the registry so it is not collected, and its reference is remembered in
// `proto->TL_REF`.
//
// Returns the thread; the caller gives it back with reback_thread_state().
static lua_State* get_thread_state(LuaHookProtocol* proto) {
	std::list<lua_State*>* TLS = &proto->TLS;
	lua_State* L = NULL;
	if (!TLS->empty()) {
		L = TLS->front();
		TLS->pop_front();
		// A finished coroutine leaves its results on its stack (see
		// lua_resume) and callers never pop them. Drop these leftovers so
		// a recycled thread does not grow its stack on every reuse.
		// Only touch the stack when the thread is not suspended.
		if (lua_status(L) == LUA_OK) {
			lua_settop(L, 0);
		}
	}
	else {
		lua_State* GL = proto->GL;
		L = lua_newthread(GL);                      // pushes the thread on GL
		int ref = luaL_ref(GL, LUA_REGISTRYINDEX);  // pops it and anchors it
		proto->TL_REF.insert(std::make_pair(L, ref));
	}
	return L;
}

// Returns thread L to the back of the idle-thread list `proto->TLS`.
void reback_thread_state(LuaHookProtocol* proto, lua_State* L) {
	std::list<lua_State*>& idle_threads = proto->TLS;
	idle_threads.push_back(L);
}

// Creates the main Lua state, opens the standard and project libraries,
// extends the package path, registers the hook API, stores `this` in the
// registry and snapshots the names present in package.loaded and _G.
LuaHookProtocol::LuaHookProtocol()
{
	lua_State* vm = luaL_newstate();
	GL = vm;

	// Libraries.
	luaL_openlibs(vm);
	luaL_open_3rd_lib(vm);
	luaL_open_3rd_server_lib(vm);
	luaL_open_pasture_socket(vm);

	// Script search paths.
	lua_State_Set_Package_Path(vm, "./liblua");
	lua_State_Set_Package_Path(vm, "./server");

	// Host bindings.
	register_hook_api(vm);
	set_global(vm, this);

	// Snapshot of the pristine VM; used by reinit() to prune later additions.
	global_package_loaded_init(vm, OldVmPackageLoaded);
	global_parms_loaded_init(vm, OldVmGlobalLoaded);
}

// Same set-up as the default constructor, and additionally stores `ppid` in
// the registry (see set_ppid) before the VM snapshot is taken.
LuaHookProtocol::LuaHookProtocol(LuaHookProtocol* ppid)
{
	lua_State* vm = luaL_newstate();
	GL = vm;

	// Libraries.
	luaL_openlibs(vm);
	luaL_open_3rd_lib(vm);
	luaL_open_3rd_server_lib(vm);
	luaL_open_pasture_socket(vm);

	// Script search paths.
	lua_State_Set_Package_Path(vm, "./liblua");
	lua_State_Set_Package_Path(vm, "./server");

	// Host bindings.
	register_hook_api(vm);
	set_global(vm, this);
	set_ppid(vm, ppid);

	// Snapshot of the pristine VM; used by reinit() to prune later additions.
	global_package_loaded_init(vm, OldVmPackageLoaded);
	global_parms_loaded_init(vm, OldVmGlobalLoaded);
}

// Releases every handle tracked in HandleRef and closes the Lua state.
// HandleRef values identify the handle kind: 0 = timer, 1 = socket,
// 2 = accepter. Accepters are stopped in a first pass, timers and sockets
// are released in a second pass.
LuaHookProtocol::~LuaHookProtocol()
{
	// Pass 1: stop accepters.
	for (auto& entry : HandleRef) {
		if (entry.second != 2) continue;
		AccepterStop((LuaAccepter*)entry.first);
	}

	// Pass 2: delete timers and close sockets.
	for (auto& entry : HandleRef) {
		if (entry.second == 0) TimerDelete((HTIMER)entry.first);
		if (entry.second == 1) HsocketClosed((HSOCKET)entry.first);
	}

	lua_close(GL);
}

// Called when `accepter` has accepted the connection `hsock`.
//
// Connections from an accepter that is not tracked in HandleRef are closed
// immediately. Otherwise the socket is registered (kind 1) and either handed
// to the coroutine stored in `accepter->L`, which is resumed with one
// argument, or queued in `accepter->listen_list` when no coroutine is stored.
void LuaHookProtocol::ConnectionAccepted(LuaAccepter* accepter, HSOCKET hsock, PROTOCOL protocol)
{
	const bool known_accepter = HandleRef.find(accepter) != HandleRef.end();
	if (!known_accepter) {
		HsocketClosed(hsock);
		return;
	}

	this->HandleRef.insert(std::make_pair(hsock, 1));

	lua_State* waiter = accepter->L;
	if (waiter == NULL) {
		// Nobody is waiting: keep the socket for later.
		accepter->listen_list.push_back(hsock);
		return;
	}

	create_socket_content(waiter, hsock, SOCKET_CONNECT);
	int nres = 0;
	const int rc = lua_resume(waiter, NULL, 1, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(waiter, rc, this, __func__, __LINE__);
	}
	// Coroutine finished (normally or with an error): recycle the thread.
	accepter->L = NULL;
	reback_thread_state(this, waiter);
}

// Called when the outgoing connection `hsock` has been established.
// Resumes the coroutine stored in `hsock->user_ptr` (if any) with one
// argument; when the coroutine finishes, the pointer is cleared and the
// thread is returned to the pool.
void LuaHookProtocol::ConnectionMade(HSOCKET hsock, PROTOCOL protocol)
{
	lua_State* waiter = (lua_State*)hsock->user_ptr;
	if (waiter == NULL) {
		return;
	}

	create_socket_content(waiter, hsock, SOCKET_CONNECT);
	int nres = 0;
	const int rc = lua_resume(waiter, NULL, 1, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(waiter, rc, this, __func__, __LINE__);
	}
	hsock->user_ptr = NULL;
	reback_thread_state(this, waiter);
}

// Called when connecting `hsock` failed.
//
// Resumes the coroutine stored in `hsock->user_ptr` (if any) with a single
// nil argument. When the coroutine finishes, the pointer is cleared and the
// thread is returned to the pool. The socket is removed from HandleRef on
// every path.
void LuaHookProtocol::ConnectionFailed(HSOCKET hsock, int err)
{
	lua_State* L = (lua_State*)hsock->user_ptr;
	if (L) {
		int nres = 0;
		lua_pushnil(L);
		int state = lua_resume(L, NULL, 1, &nres);
		if (state != LUA_YIELD) {
			if (state != LUA_OK) {
				sheeps_report(L, state, this, __func__, __LINE__);
			}
			hsock->user_ptr = NULL;
			reback_thread_state(this, L);
		}
		// On LUA_YIELD the coroutine keeps its thread, but the socket must
		// still be dropped from HandleRef below. Returning early here used
		// to leave the failed socket registered, so teardown (reinit or the
		// destructor) would later call HsocketClosed on it.
	}
	this->HandleRef.erase(hsock);
}

// Called when the connection `hsock` has been closed.
//
// Resumes the coroutine stored in `hsock->user_ptr` (if any) with a single
// empty-string argument. When the coroutine finishes, the pointer is cleared
// and the thread is returned to the pool. The socket is removed from
// HandleRef on every path.
void LuaHookProtocol::ConnectionClosed(HSOCKET hsock, int err)
{
	lua_State* L = (lua_State*)hsock->user_ptr;
	if (L) {
		int nres = 0;
		lua_pushstring(L, "");
		int state = lua_resume(L, NULL, 1, &nres);
		if (state != LUA_YIELD) {
			if (state != LUA_OK) {
				sheeps_report(L, state, this, __func__, __LINE__);
			}
			hsock->user_ptr = NULL;
			reback_thread_state(this, L);
		}
		// On LUA_YIELD the coroutine keeps its thread, but the socket must
		// still be dropped from HandleRef below. Returning early here used
		// to leave the closed socket registered, so teardown (reinit or the
		// destructor) would later call HsocketClosed on it again.
	}
	this->HandleRef.erase(hsock);
}

// Called when `len` bytes at `data` were received on `hsock`.
// If a coroutine is stored in `hsock->user_ptr`, the bytes are pushed as a
// Lua string, `hsock->offset` is reset to 0 and the coroutine is resumed with
// that one argument. When the coroutine finishes, the pointer is cleared and
// the thread is returned to the pool. Without a stored coroutine nothing
// happens.
void LuaHookProtocol::ConnectionRecved(HSOCKET hsock, const char* data, int len)
{
	lua_State* reader = (lua_State*)hsock->user_ptr;
	if (reader == NULL) {
		return;
	}

	lua_pushlstring(reader, data, len);
	hsock->offset = 0;

	int nres = 0;
	const int rc = lua_resume(reader, NULL, 1, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(reader, rc, this, __func__, __LINE__);
	}
	hsock->user_ptr = NULL;
	reback_thread_state(this, reader);
}

// Loads the Lua file `script` and runs it as a coroutine on a pooled thread.
//
// Load and runtime errors are reported through sheeps_report(). The thread
// is returned to the pool unless the script yielded, in which case it stays
// with the suspended coroutine.
void LuaHookProtocol::RunScript(const char* script) {
	lua_State* L = get_thread_state(this);

	int state = luaL_loadfile(L, script);
	if (state != LUA_OK) {
		sheeps_report(L, state, this, __func__, __LINE__);
		// sheeps_report() has reset the thread, so it can be reused. Without
		// this the thread stayed anchored in the registry but was never
		// handed out again, leaking one thread per failed load.
		reback_thread_state(this, L);
		return;
	}

	int nres = 0;
	state = lua_resume(L, NULL, 0, &nres);
	if (state == LUA_YIELD) {
		return;
	}
	else if (state != LUA_OK) {
		sheeps_report(L, state, this, __func__, __LINE__);
	}
	reback_thread_state(this, L);
}

void LuaHookProtocol::init() {
	HTIMER timer = TimerCreate(this, NULL, 1000, 1000, [](HTIMER timer, BaseWorker* proto, void* user_data) {
		((LuaHookProtocol*)proto)->timer_out();
		});
	this->HandleRef.insert(std::make_pair(timer, 0));

	const char* script_file = config_get_string_value("server", "script", "server/init.lua");
	snprintf(ServerInfoString, sizeof(ServerInfoString), "扩展脚本：[%s]", script_file);
	this->RunScript(script_file);
}

// Tears down everything created since construction and starts over:
// stops and deletes accepters, deletes timers, closes sockets, releases the
// registry references of all pooled threads, clears the bookkeeping
// containers, prunes package.loaded and _G back to the recorded snapshot and
// finally calls init() again.
void LuaHookProtocol::reinit() {
	// Pass 1: accepters (kind 2) are stopped and destroyed.
	for (auto& entry : HandleRef) {
		if (entry.second != 2) continue;
		LuaAccepter* accepter = (LuaAccepter*)entry.first;
		AccepterStop(accepter);
		delete accepter;
	}

	// Pass 2: timers (kind 0) and sockets (kind 1).
	for (auto& entry : HandleRef) {
		if (entry.second == 0) TimerDelete((HTIMER)entry.first);
		if (entry.second == 1) HsocketClosed((HSOCKET)entry.first);
	}

	// Drop the registry anchors of every thread created on GL.
	for (auto& anchored : TL_REF) {
		luaL_unref(GL, LUA_REGISTRYINDEX, anchored.second);
	}

	HandleRef.clear();
	TL_REF.clear();
	TLS.clear();

	// Remove modules and globals that were added after construction.
	reset_luavm_package(GL, OldVmPackageLoaded);
	reset_luavm_gloable(GL, OldVmGlobalLoaded);

	this->init();
}

// Timer callback for a sleeping coroutine: resumes the coroutine stored in
// `timer->user_data` (if any) with no arguments and returns its thread to
// the pool once it has finished.
void LuaHookProtocol::time_sleep(HTIMER timer) {
	lua_State* sleeper = (lua_State*)timer->user_data;
	if (sleeper == NULL) {
		return;
	}

	int nres = 0;
	const int rc = lua_resume(sleeper, NULL, 0, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(sleeper, rc, this, __func__, __LINE__);
	}
	reback_thread_state(this, sleeper);
}

// Registers `hsock` (kind 1) and runs the Lua global named after this
// function ("http_handler") as a coroutine with two arguments: the value
// pushed by create_socket_content() and the request bytes as a Lua string.
// The thread is returned to the pool unless the coroutine yielded.
void LuaHookProtocol::http_handler(HSOCKET hsock, const char* data, int len)
{
	HandleRef.insert(std::make_pair(hsock, 1));

	lua_State* co = get_thread_state(this);
	lua_getglobal(co, __func__);                       // function to run
	create_socket_content(co, hsock, SOCKET_CONNECT);  // argument 1
	lua_pushlstring(co, data, len);                    // argument 2

	int nres = 0;
	const int rc = lua_resume(co, NULL, 2, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(co, rc, this, __func__, __LINE__);
	}
	reback_thread_state(this, co);
}

// Runs the Lua global named after this function ("event_handler") as a
// coroutine with no arguments. `root` is not used by this function.
// The thread is returned to the pool unless the coroutine yielded.
void LuaHookProtocol::event_handler(cJSON* root)
{
	lua_State* co = get_thread_state(this);
	lua_getglobal(co, __func__);

	int nres = 0;
	const int rc = lua_resume(co, NULL, 0, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(co, rc, this, __func__, __LINE__);
	}
	reback_thread_state(this, co);
}

// Runs the Lua global "event_handler" as a coroutine with two arguments:
// the event name and a table `{ task_id = taskcfg->taskID }`.
// The thread is returned to the pool unless the coroutine yielded.
void LuaHookProtocol::event_task(ServerTaskConfig* taskcfg, const char* event) {
	lua_State* co = get_thread_state(this);
	lua_getglobal(co, "event_handler");

	// Argument 1: event name.
	lua_pushstring(co, event);

	// Argument 2: details table.
	lua_newtable(co);
	lua_pushinteger(co, taskcfg->taskID);
	lua_setfield(co, -2, "task_id");

	int nres = 0;
	const int rc = lua_resume(co, NULL, 2, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(co, rc, this, __func__, __LINE__);
	}
	reback_thread_state(this, co);
}

// Runs the Lua global "event_handler" as a coroutine with two arguments:
// the event name and a table `{ ip = agent->ip, port = agent->port }`.
// The thread is returned to the pool unless the coroutine yielded.
void LuaHookProtocol::event_agent(t_sheeps_agent* agent, const char* event) {
	lua_State* co = get_thread_state(this);
	lua_getglobal(co, "event_handler");

	// Argument 1: event name.
	lua_pushstring(co, event);

	// Argument 2: details table.
	lua_newtable(co);
	lua_pushstring(co, agent->ip);
	lua_setfield(co, -2, "ip");
	lua_pushinteger(co, agent->port);
	lua_setfield(co, -2, "port");

	int nres = 0;
	const int rc = lua_resume(co, NULL, 2, &nres);
	if (rc == LUA_YIELD) {
		return;  // coroutine is still running; it keeps its thread
	}
	if (rc != LUA_OK) {
		sheeps_report(co, rc, this, __func__, __LINE__);
	}
	reback_thread_state(this, co);
}

void LuaHookProtocol::proxy_connect_open_hook(const char* proxy_type, const char* host, HSOCKET from, HSOCKET to, PROTOCOL protocol){
	lua_newtable(GL);

	lua_pushstring(GL, "type");
	lua_pushstring(GL, proxy_type);
	lua_settable(GL, -3);

	lua_pushstring(GL, "host");
	lua_pushstring(GL, host);
	lua_settable(GL, -3);

	lua_pushstring(GL, "client");
	lua_pushlightuserdata(GL, from);
	lua_settable(GL, -3);

	lua_pushstring(GL, "server");
	lua_pushlightuserdata(GL, to);
	lua_settable(GL, -3);

	lua_pushstring(GL, "protocol");
	lua_pushinteger(GL, protocol);
	lua_settable(GL, -3);

	char addr[40];
	int port;
	HsocketPeerAddr(from, addr, sizeof(addr), &port);
	lua_pushstring(GL, "client_ip");
	lua_pushstring(GL, addr);
	lua_settable(GL, -3);

	lua_pushstring(GL, "client_port");
	lua_pushinteger(GL, port);
	lua_settable(GL, -3);

	HsocketPeerAddr(to, addr, sizeof(addr), &port);
	lua_pushstring(GL, "server_ip");
	lua_pushstring(GL, addr);
	lua_settable(GL, -3);

	lua_pushstring(GL, "server_port");
	lua_pushinteger(GL, port);
	lua_settable(GL, -3);

	int ref = luaL_ref(GL, LUA_REGISTRYINDEX);
	this->ProxyRef.insert(std::make_pair(from, ref));

	lua_getglobal(GL, CONNECT_OPEN_HOOK_FUNC);
	lua_rawgeti(GL, LUA_REGISTRYINDEX, ref);
	int state = lua_pcall(GL, 1, 0, 0);
	if (state != LUA_OK) {
		printf("%s:%d %s\n", __func__, __LINE__, lua_tostring(GL, -1));
		lua_pop(GL, 1);
	}
}

void LuaHookProtocol::proxy_connect_faile_hook(const char* proxy_type, const char* host, HSOCKET from, HSOCKET to, PROTOCOL protocol){
	lua_getglobal(GL, CONNECT_FAILE_HOOK_FUNC);
	lua_newtable(GL);

	lua_pushstring(GL, "type");
	lua_pushstring(GL, proxy_type);
	lua_settable(GL, -3);

	lua_pushstring(GL, "host");
	lua_pushstring(GL, host);
	lua_settable(GL, -3);

	lua_pushstring(GL, "client");
	lua_pushlightuserdata(GL, NULL);
	lua_settable(GL, -3);

	lua_pushstring(GL, "server");
	lua_pushlightuserdata(GL, NULL);
	lua_settable(GL, -3);

	lua_pushstring(GL, "protocol");
	lua_pushinteger(GL, protocol);
	lua_settable(GL, -3);

	char addr[40];
	int port;
	HsocketPeerAddr(from, addr, sizeof(addr), &port);
	lua_pushstring(GL, "client_ip");
	lua_pushstring(GL, addr);
	lua_settable(GL, -3);

	lua_pushstring(GL, "client_port");
	lua_pushinteger(GL, port);
	lua_settable(GL, -3);

	HsocketPeerAddr(to, addr, sizeof(addr), &port);
	lua_pushstring(GL, "server_ip");
	lua_pushstring(GL, addr);
	lua_settable(GL, -3);

	lua_pushstring(GL, "server_port");
	lua_pushinteger(GL, port);
	lua_settable(GL, -3);

	int state = lua_pcall(GL, 1, 0, 0);
	if (state != LUA_OK) {
		printf("%s:%d %s\n", __func__, __LINE__, lua_tostring(GL, -1));
		lua_pop(GL, 1);
	}
}

void LuaHookProtocol::proxy_connect_close_hook(const char* proxy_type, HSOCKET from, HSOCKET to, PROTOCOL protocol){
	std::map<void*, int>::iterator iter = ProxyRef.find(from);
	if (iter != ProxyRef.end()){
		int ref = iter->second;
		this->ProxyRef.erase(from);

		lua_getglobal(GL, CONNECT_CLOSE_HOOK_FUNC);
		lua_rawgeti(GL, LUA_REGISTRYINDEX, ref);

		if (!from){
			lua_pushstring(GL, "client");
			lua_pushlightuserdata(GL, from);
			lua_settable(GL, -3);
		}
		
		if (!to){
			lua_pushstring(GL, "server");
			lua_pushlightuserdata(GL, to);
			lua_settable(GL, -3);
		}

		int state = lua_pcall(GL, 1, 0, 0);
		if (state != LUA_OK) {
			printf("%s:%d %s\n", __func__, __LINE__, lua_tostring(GL, -1));
			lua_pop(GL, 1);
		}
		luaL_unref(GL, LUA_REGISTRYINDEX, ref);
	}
}

int LuaHookProtocol::proxy_connect_send_hook(const char* proxy_type, HSOCKET from, HSOCKET to, PROTOCOL protocol, const char* data, int len){
	int ret = 0;
	std::map<void*, int>::iterator iter = ProxyRef.find(from);
	if (iter != ProxyRef.end()){
		int ref = iter->second;

		lua_getglobal(GL, CONNECT_SEND_HOOK_FUNC);
		lua_rawgeti(GL, LUA_REGISTRYINDEX, ref);

		if (!from){
			lua_pushstring(GL, "client");
			lua_pushlightuserdata(GL, from);
			lua_settable(GL, -3);
		}
		
		if (!to){
			lua_pushstring(GL, "server");
			lua_pushlightuserdata(GL, to);
			lua_settable(GL, -3);
		}
		
		lua_pushlstring(GL, data, len);
		int state = lua_pcall(GL, 2, 1, 0);
		if (state != LUA_OK) {
			printf("%s:%d %s\n", __func__, __LINE__, lua_tostring(GL, -1));
			lua_pop(GL, 1);
			return -1;
		}
		ret = (int)lua_tointeger(GL, -1);
		lua_pop(GL, 1);
	}
	return ret;
}

int LuaHookProtocol::proxy_connect_recv_hook(const char* proxy_type, HSOCKET from, HSOCKET to, PROTOCOL protocol, const char* data, int len){
	int ret = 0;
	std::map<void*, int>::iterator iter = ProxyRef.find(from);
	if (iter != ProxyRef.end()){
		int ref = iter->second;

		lua_getglobal(GL, CONNECT_RECV_HOOK_FUNC);
		lua_rawgeti(GL, LUA_REGISTRYINDEX, ref);

		if (!from){
			lua_pushstring(GL, "client");
			lua_pushlightuserdata(GL, from);
			lua_settable(GL, -3);
		}
		
		if (!to){
			lua_pushstring(GL, "server");
			lua_pushlightuserdata(GL, to);
			lua_settable(GL, -3);
		}

		lua_pushlstring(GL, data, len);
		int state = lua_pcall(GL, 2, 1, 0);
		if (state != LUA_OK) {
			printf("%s:%d %s\n", __func__, __LINE__, lua_tostring(GL, -1));
			lua_pop(GL, 1);
			return -1;
		}
		ret = (int)lua_tointeger(GL, -1);
		lua_pop(GL, 1);
	}
	return ret;
}

// Periodic housekeeping: runs a full Lua garbage collection on GL when more
// than 10 time units (as measured by NOWTIME) have passed since the last
// one, and records the time of that collection in time_last_gc.
void LuaHookProtocol::timer_out(){
	const time_t now = NOWTIME;
	const bool gc_due = now - time_last_gc > 10;
	if (!gc_due) {
		return;
	}
	lua_gc(GL, LUA_GCCOLLECT, 0);
	time_last_gc = now;
}